# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Any
from typing import Dict
from typing import Sequence
from typing import Tuple

import paddle
from paddle import nn
from typeguard import typechecked

from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
from paddlespeech.t2s.modules.losses import ssim
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask


class FastSpeech2MIDI(FastSpeech2):
    """The Fastspeech2 module of DiffSinger.
    """

    @typechecked
    def __init__(
            self,
            # fastspeech2 network structure related
            idim: int,
            odim: int,
            fastspeech2_params: Dict[str, Any],
            # note emb
            note_num: int=300,
            # is_slur emb
            is_slur_num: int=2,
            use_energy_pred: bool=False,
            use_postnet: bool=False, ):
        """Initialize FastSpeech2 module for svs.
        Args:
            fastspeech2_params (Dict):
                The config of FastSpeech2 module on DiffSinger model
            note_num (Optional[int]): 
                Number of note. If not None, assume that the
                note_ids will be provided as the input and use note_embedding_table.
            is_slur_num (Optional[int]): 
                Number of note. If not None, assume that the
                is_slur_ids will be provided as the input
    
        """
        super().__init__(idim=idim, odim=odim, **fastspeech2_params)
        self.use_energy_pred = use_energy_pred
        self.use_postnet = use_postnet
        if not self.use_postnet:
            self.postnet = None

        self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
            "adim"]

        # note_ embed
        self.note_embedding_table = nn.Embedding(
            num_embeddings=note_num,
            embedding_dim=self.note_embed_dim,
            padding_idx=self.padding_idx)
        self.note_dur_layer = nn.Linear(1, self.note_embed_dim)

        # slur embed
        self.is_slur_embedding_table = nn.Embedding(
            num_embeddings=is_slur_num,
            embedding_dim=self.is_slur_embed_dim,
            padding_idx=self.padding_idx)

    def forward(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            durations: paddle.Tensor,
            pitch: paddle.Tensor,
            energy: paddle.Tensor,
            spk_emb: paddle.Tensor=None,
            spk_id: paddle.Tensor=None,
    ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
        """Calculate forward propagation.

        Args:
            text(Tensor(int64)): 
                Batch of padded token (phone) ids (B, Tmax).
            note(Tensor(int64)): 
                Batch of padded note (element in music score) ids (B, Tmax).
            note_dur(Tensor(float32)): 
                Batch of padded note durations in seconds (element in music score) (B, Tmax).
            is_slur(Tensor(int64)): 
                Batch of padded slur (element in music score) ids (B, Tmax).
            text_lengths(Tensor(int64)): 
                Batch of phone lengths of each input (B,).
            speech(Tensor[float32]): 
                Batch of padded target features (e.g. mel) (B, Lmax, odim).
            speech_lengths(Tensor(int64)): 
                Batch of the lengths of each target features (B,).
            durations(Tensor(int64)): 
                Batch of padded token durations in frame (B, Tmax).
            pitch(Tensor[float32]): 
                Batch of padded frame-averaged pitch (B, Lmax, 1).
            energy(Tensor[float32]): 
                Batch of padded frame-averaged energy (B, Lmax, 1).
            spk_emb(Tensor[float32], optional): 
                Batch of speaker embeddings (B, spk_embed_dim).
            spk_id(Tnesor[int64], optional(int64)): 
                Batch of speaker ids (B,)

        Returns:

        """
        xs = paddle.cast(text, 'int64')
        note = paddle.cast(note, 'int64')
        note_dur = paddle.cast(note_dur, 'float32')
        is_slur = paddle.cast(is_slur, 'int64')
        ilens = paddle.cast(text_lengths, 'int64')
        olens = paddle.cast(speech_lengths, 'int64')
        ds = paddle.cast(durations, 'int64')
        ps = pitch
        es = energy
        ys = speech
        olens = speech_lengths
        if spk_id is not None:
            spk_id = paddle.cast(spk_id, 'int64')
        # forward propagation
        before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            olens=olens,
            ds=ds,
            ps=ps,
            es=es,
            is_inference=False,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens - olens % self.reduction_factor
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits

    def _forward(
            self,
            xs: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor=None,
            ds: paddle.Tensor=None,
            ps: paddle.Tensor=None,
            es: paddle.Tensor=None,
            is_inference: bool=False,
            is_train_diffusion: bool=False,
            return_after_enc=False,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> Sequence[paddle.Tensor]:

        before_outs = after_outs = d_outs = p_outs = e_outs = spk_logits = None
        # forward encoder
        masks = self._source_mask(ilens)
        note_emb = self.note_embedding_table(note)
        note_dur_emb = self.note_dur_layer(paddle.unsqueeze(note_dur, axis=-1))
        is_slur_emb = self.is_slur_embedding_table(is_slur)

        # (B, Tmax, adim)
        hs, _ = self.encoder(
            xs=xs,
            masks=masks,
            note_emb=note_emb,
            note_dur_emb=note_dur_emb,
            is_slur_emb=is_slur_emb, )

        if self.spk_num and self.enable_speaker_classifier and not is_inference:
            hs_for_spk_cls = self.grad_reverse(hs)
            spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
        else:
            spk_logits = None

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            # spk_emb has a higher priority than spk_id
            if spk_emb is not None:
                hs = self._integrate_with_spk_embed(hs, spk_emb)
            elif spk_id is not None:
                spk_emb = self.spk_embedding_table(spk_id)
                hs = self._integrate_with_spk_embed(hs, spk_emb)

        # forward duration predictor (phone-level) and variance predictors (frame-level)
        d_masks = make_pad_mask(ilens)
        if olens is not None:
            pitch_masks = make_pad_mask(olens).unsqueeze(-1)
        else:
            pitch_masks = None

        # inference for decoder input for diffusion
        if is_train_diffusion:
            hs = self.length_regulator(hs, ds, is_inference=False)
            p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs
            if self.use_energy_pred:
                e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        elif is_inference:
            # (B, Tmax)
            if ds is not None:
                d_outs = ds
            else:
                d_outs = self.duration_predictor.inference(hs, d_masks)

            # (B, Lmax, adim)
            hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)

            if ps is not None:
                p_outs = ps
            else:
                if self.stop_gradient_from_pitch_predictor:
                    p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
                else:
                    p_outs = self.pitch_predictor(hs, pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                if es is not None:
                    e_outs = es
                else:
                    if self.stop_gradient_from_energy_predictor:
                        e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                    else:
                        e_outs = self.energy_predictor(hs, pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        # training
        else:
            d_outs = self.duration_predictor(hs, d_masks)
            # (B, Lmax, adim)
            hs = self.length_regulator(hs, ds, is_inference=False)
            if self.stop_gradient_from_pitch_predictor:
                p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
            else:
                p_outs = self.pitch_predictor(hs, pitch_masks)
            p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                if self.stop_gradient_from_energy_predictor:
                    e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                else:
                    e_outs = self.energy_predictor(hs, pitch_masks)
                e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
                    (0, 2, 1))
                hs += e_embs

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = paddle.to_tensor(
                    [olen // self.reduction_factor for olen in olens.numpy()])
            else:
                olens_in = olens
            # (B, 1, T)
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None

        if return_after_enc:
            return hs, h_masks

        if self.decoder_type == 'cnndecoder':
            # remove output masks for dygraph to static graph
            zs = self.decoder(hs, h_masks)
            before_outs = zs
        else:
            # (B, Lmax, adim)
            zs, _ = self.decoder(hs, h_masks)
            # (B, Lmax, odim)
            before_outs = self.feat_out(zs).reshape(
                (paddle.shape(zs)[0], -1, self.odim))

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))

        return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits

    def encoder_infer(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        xs = paddle.cast(text, 'int64').unsqueeze(0)
        note = paddle.cast(note, 'int64').unsqueeze(0)
        note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
        is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
        # setup batch axis
        ilens = paddle.shape(xs)[1]

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        # (1, L, odim)
        # use *_ to avoid bug in dygraph to static graph    
        hs, _ = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            is_inference=True,
            return_after_enc=True,
            alpha=alpha,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        return hs

    # get encoder output for diffusion training
    def encoder_infer_batch(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            ds: paddle.Tensor=None,
            ps: paddle.Tensor=None,
            es: paddle.Tensor=None,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:

        xs = paddle.cast(text, 'int64')
        note = paddle.cast(note, 'int64')
        note_dur = paddle.cast(note_dur, 'float32')
        is_slur = paddle.cast(is_slur, 'int64')
        ilens = paddle.cast(text_lengths, 'int64')
        olens = paddle.cast(speech_lengths, 'int64')

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        # (1, L, odim)
        # use *_ to avoid bug in dygraph to static graph    
        hs, h_masks = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            olens=olens,
            ds=ds,
            ps=ps,
            es=es,
            return_after_enc=True,
            is_train_diffusion=True,
            alpha=alpha,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        return hs, h_masks

    def inference(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            durations: paddle.Tensor=None,
            pitch: paddle.Tensor=None,
            energy: paddle.Tensor=None,
            alpha: float=1.0,
            use_teacher_forcing: bool=False,
            spk_emb=None,
            spk_id=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text(Tensor(int64)): 
                Input sequence of characters (T,).
            note(Tensor(int64)): 
                Input note (element in music score) ids (T,).
            note_dur(Tensor(float32)): 
               Input note durations in seconds (element in music score) (T,).
            is_slur(Tensor(int64)): 
                Input slur (element in music score) ids (T,).
            durations(Tensor, optional (int64)): 
                Groundtruth of duration (T,).
            pitch(Tensor, optional): 
                Groundtruth of token-averaged pitch (T, 1).
            energy(Tensor, optional): 
                Groundtruth of token-averaged energy (T, 1).
            alpha(float, optional): 
                Alpha to control the speed.
            use_teacher_forcing(bool, optional): 
                Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.
            spk_emb(Tensor, optional, optional): 
                peaker embedding vector (spk_embed_dim,). (Default value = None)
            spk_id(Tensor, optional(int64), optional): 
                spk ids (1,). (Default value = None)

        Returns:

        """
        xs = paddle.cast(text, 'int64').unsqueeze(0)
        note = paddle.cast(note, 'int64').unsqueeze(0)
        note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
        is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
        d, p, e = durations, pitch, energy
        # setup batch axis
        ilens = paddle.shape(xs)[1]

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds = d.unsqueeze(0) if d is not None else None
            ps = p.unsqueeze(0) if p is not None else None
            es = e.unsqueeze(0) if e is not None else None

            # (1, L, odim)
            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                xs=xs,
                note=note,
                note_dur=note_dur,
                is_slur=is_slur,
                ilens=ilens,
                ds=ds,
                ps=ps,
                es=es,
                spk_emb=spk_emb,
                spk_id=spk_id,
                is_inference=True)
        else:
            # (1, L, odim)
            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                xs=xs,
                note=note,
                note_dur=note_dur,
                is_slur=is_slur,
                ilens=ilens,
                is_inference=True,
                alpha=alpha,
                spk_emb=spk_emb,
                spk_id=spk_id, )

        if e_outs is None:
            e_outs = [None]

        return outs[0], d_outs[0], p_outs[0], e_outs[0]


class FastSpeech2MIDILoss(FastSpeech2Loss):
    """Loss function module for DiffSinger."""

    @typechecked
    def __init__(self, use_masking: bool=True,
                 use_weighted_masking: bool=False):
        """Initialize feed-forward Transformer loss module.
        Args:
            use_masking (bool): 
                Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool): 
                Whether to weighted masking in loss calculation.
        """
        super().__init__(use_masking, use_weighted_masking)

    def forward(
            self,
            after_outs: paddle.Tensor,
            before_outs: paddle.Tensor,
            d_outs: paddle.Tensor,
            p_outs: paddle.Tensor,
            e_outs: paddle.Tensor,
            ys: paddle.Tensor,
            ds: paddle.Tensor,
            ps: paddle.Tensor,
            es: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor,
            spk_logits: paddle.Tensor=None,
            spk_ids: paddle.Tensor=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
               paddle.Tensor, ]:
        """Calculate forward propagation.

        Args:
            after_outs(Tensor):  
                Batch of outputs after postnets (B, Lmax, odim).
            before_outs(Tensor): 
                Batch of outputs before postnets (B, Lmax, odim).
            d_outs(Tensor): 
                Batch of outputs of duration predictor (B, Tmax).
            p_outs(Tensor): 
                Batch of outputs of pitch predictor (B, Lmax, 1).
            e_outs(Tensor): 
                Batch of outputs of energy predictor (B, Lmax, 1).
            ys(Tensor): 
                Batch of target features (B, Lmax, odim).
            ds(Tensor): 
                Batch of durations (B, Tmax).
            ps(Tensor): 
                Batch of target frame-averaged pitch (B, Lmax, 1).
            es(Tensor): 
                Batch of target frame-averaged energy (B, Lmax, 1).
            ilens(Tensor): 
                Batch of the lengths of each input (B,).
            olens(Tensor): 
                Batch of the lengths of each target (B,).
            spk_logits(Option[Tensor]):
                Batch of outputs after speaker classifier (B, Lmax, num_spk)
            spk_ids(Option[Tensor]):
                Batch of target spk_id (B,)
        

        Returns:

        
        """
        l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0

        # apply mask to remove padded part
        if self.use_masking:
            # make feature for ssim loss
            out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
            before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0)
            if not paddle.equal_all(after_outs, before_outs):
                after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0)
            ys_ssim = masked_fill(ys, out_pad_masks, 0.0)

            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            before_outs = before_outs.masked_select(
                out_masks.broadcast_to(before_outs.shape))
            if not paddle.equal_all(after_outs, before_outs):
                after_outs = after_outs.masked_select(
                    out_masks.broadcast_to(after_outs.shape))
            ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
            duration_masks = make_non_pad_mask(ilens)
            d_outs = d_outs.masked_select(
                duration_masks.broadcast_to(d_outs.shape))
            ds = ds.masked_select(duration_masks.broadcast_to(ds.shape))
            pitch_masks = out_masks
            p_outs = p_outs.masked_select(
                pitch_masks.broadcast_to(p_outs.shape))
            ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
            if e_outs is not None:
                e_outs = e_outs.masked_select(
                    pitch_masks.broadcast_to(e_outs.shape))
                es = es.masked_select(pitch_masks.broadcast_to(es.shape))

            if spk_logits is not None and spk_ids is not None:
                batch_size = spk_ids.shape[0]
                spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
                                                   None)
                spk_logits = paddle.reshape(spk_logits,
                                            [-1, spk_logits.shape[-1]])
                mask_index = spk_logits.abs().sum(axis=1) != 0
                spk_ids = spk_ids[mask_index]
                spk_logits = spk_logits[mask_index]

        # calculate loss
        l1_loss = self.l1_criterion(before_outs, ys)
        ssim_loss = 1.0 - ssim(
            before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
        if not paddle.equal_all(after_outs, before_outs):
            l1_loss += self.l1_criterion(after_outs, ys)
            ssim_loss += (
                1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)))
        l1_loss = l1_loss * 0.5
        ssim_loss = ssim_loss * 0.5

        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.l1_criterion(p_outs, ps)
        if e_outs is not None:
            energy_loss = self.l1_criterion(e_outs, es)

        if spk_logits is not None and spk_ids is not None:
            speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size

        # make weighted mask and apply it
        if self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            out_weights = out_masks.cast(dtype=paddle.float32) / out_masks.cast(
                dtype=paddle.float32).sum(
                    axis=1, keepdim=True)
            out_weights /= ys.shape[0] * ys.shape[2]
            duration_masks = make_non_pad_mask(ilens)
            duration_weights = (duration_masks.cast(dtype=paddle.float32) /
                                duration_masks.cast(dtype=paddle.float32).sum(
                                    axis=1, keepdim=True))
            duration_weights /= ds.shape[0]

            # apply weight
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                out_masks.broadcast_to(l1_loss.shape)).sum()
            ssim_loss = ssim_loss.multiply(out_weights)
            ssim_loss = ssim_loss.masked_select(
                out_masks.broadcast_to(ssim_loss.shape)).sum()
            duration_loss = (duration_loss.multiply(duration_weights)
                             .masked_select(duration_masks).sum())
            pitch_masks = out_masks
            pitch_weights = out_weights
            pitch_loss = pitch_loss.multiply(pitch_weights)
            pitch_loss = pitch_loss.masked_select(
                pitch_masks.broadcast_to(pitch_loss.shape)).sum()
            if e_outs is not None:
                energy_loss = energy_loss.multiply(pitch_weights)
                energy_loss = energy_loss.masked_select(
                    pitch_masks.broadcast_to(energy_loss.shape)).sum()

        return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss
